# encoding: utf-8
# @Time:    :2025/2/5 21:38
import os.path
import random
import time

import numpy as np
import torch
import warnings
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import Transformer
from model.LMConfig import LMConfig

warnings.filterwarnings('ignore')


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def init_model(lm_config):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model_from = 1  # 1从权重，2用transformers

    if model_from == 1:
        moe_path = '_moe' if lm_config.use_moe else ''
        # ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
        ckp = f'./out2/pretrain_epoch_2_loss_0.49748_back.pth'
        ckp = f'./out2/not_exists_checkpoint.pth'

        model = Transformer(lm_config)
        if os.path.exists(ckp):
            print("loading checkpoints...")
            state_dict = torch.load(ckp, map_location=device)

            # 处理不需要的前缀
            unwanted_prefix = '_orig_mod.'
            for k, v in list(state_dict.items()):
                if k.startswith(unwanted_prefix):
                    state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

            for k, v in list(state_dict.items()):
                if 'mask' in k:
                    del state_dict[k]

            # 加载到模型中
            model.load_state_dict(state_dict, strict=False)
    else:
        model = AutoModelForCausalLM.from_pretrained('minimind', trust_remote_code=True)
    model = model.to(device)

    print(f'模型参数: {count_parameters(model) / 1e6} 百万 = {count_parameters(model) / 1e9} B (Billion)')
    print(f'模型参数: {count_parameters(model) / 1024/1024} M = {count_parameters(model) / 1024/1024/1024} B (Billion)')
    return model, tokenizer


def setup_seed(seed):
    random.seed(seed)  # 设置 Python 的随机种子
    np.random.seed(seed)  # 设置 NumPy 的随机种子
    torch.manual_seed(seed)  # 设置 PyTorch 的随机种子
    torch.cuda.manual_seed(seed)  # 为当前 GPU 设置随机种子（如果有）
    torch.cuda.manual_seed_all(seed)  # 为所有 GPU 设置随机种子（如果有）
    torch.backends.cudnn.deterministic = True  # 确保每次返回的卷积算法是确定的
    torch.backends.cudnn.benchmark = False  # 关闭 cuDNN 的自动调优，避免不确定性


if __name__ == '__main__':
    """
    这边测试的其实是没有训练，直接进行推理的时候，模型的表现
    
    """
    out_dir = "out"
    start = ""
    temperature = 0.7
    top_k = 8
    setup_seed(1337)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using device:", device)
    dtype = "bfloat16"
    max_seq_len = 128
    lm_config = LMConfig()
    lm_config.max_seq_len = max_seq_len

    # --------------------------------------------

    model, tokenizer = init_model(lm_config)
    model.eval()
    answer_way = 1
    stream = True

    prompt_datas = [
        '椭圆和圆的区别',
        '中国关于马克思主义基本原理',
        '人类大脑的主要功能是',
        '万有引力是',
        '世界上人口最多的国家是',
        'DNA的全称是',
        '数学中π的值大约是',
        '世界上最高的山峰是',
        '太阳系中最大的行星是',
        '二氧化碳的化学分子式是',
        '地球上最大的动物是',
        '地球自转一圈大约需要',
        '杭州市的美食有',
        '江苏省的最好的大学',
    ]

    qa_index = 0
    while True:

        if answer_way == 1:
            # run generation
            prompt = input('用户：')
        else:
            if qa_index >= len(prompt_datas):
                break
            prompt = prompt_datas[qa_index]
            print('问题：', prompt)
            qa_index += 1
        start = time.time()
        prompt = tokenizer.bos_token + prompt
        x = tokenizer(prompt).data['input_ids']
        x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])

        with torch.no_grad():
            res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=max_seq_len, temperature=temperature,
                                   top_k=top_k, stream=stream)
            print('回答：', end='')
            try:
                y = next(res_y)
            except StopIteration:
                print("No answer")
                continue

            history_idx = 0
            while y != None:
                answer = tokenizer.decode(y[0].tolist())
                if answer and answer[-1] == '�':
                    try:
                        y = next(res_y)
                    except:
                        break
                    continue
                # print(answer)
                if not len(answer):
                    try:
                        y = next(res_y)
                    except:
                        break
                    continue

                print(answer[history_idx:], end='', flush=True)
                try:
                    y = next(res_y)
                except:
                    break
                history_idx = len(answer)
                if not stream:
                    break

            print('\n')

        end = time.time()
        print(end - start, 's')


